# -*- coding: utf-8 -*- 
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
import tensorflow_addons as tfa
import warnings
import matplotlib.pyplot as plt
from solve_q2 import return_enterp

# Do not display warnings while the script runs.
warnings.filterwarnings("ignore")

# Font properties handed to the legends drawn by plot_train_test.
font1 = dict(family='Times New Roman', weight='normal', size=20)

# Global matplotlib settings: SimHei as the sans-serif font (presumably so
# that CJK text renders -- confirm) and 'axes.unicode_minus' turned off.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

def plot_train_test(history):
    """Plot the recorded training curves in two side-by-side panels and show them.

    The left panel shows loss and the right panel shows categorical accuracy.
    In both panels the training series is drawn solid (label 'Train') and the
    validation series is drawn dashed (label 'Test').

    Args:
        history: Object whose ``history`` attribute is a mapping that contains
            the series 'loss', 'val_loss', 'categorical_accuracy' and
            'val_categorical_accuracy' (this script passes the value returned
            by ``model.fit``).
    """
    # (subplot position, training key, validation key, y-axis label)
    panels = (
        (1, 'loss', 'val_loss', 'loss'),
        (2, 'categorical_accuracy', 'val_categorical_accuracy', 'Acc'),
    )

    plt.figure(figsize=(12, 4))
    # Only the left/right margins and the horizontal gap between the panels
    # are overridden; the remaining spacing parameters keep their defaults.
    plt.subplots_adjust(left=0.125, right=0.9, wspace=0.3)

    curves = history.history
    for position, train_key, val_key, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], linewidth=3, label='Train')
        plt.plot(curves[val_key], linewidth=3, linestyle='dashed', label='Test')
        plt.xlabel('Epoch', fontsize=20)
        plt.ylabel(y_label, fontsize=20)
        plt.legend(prop=font1)

    plt.show()

def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load files produced by
    # this project.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


if __name__ == '__main__':
    # Script entry point: load the pickled features and one-hot labels, train
    # a one-hidden-layer dense classifier, checkpoint the best epoch, and plot
    # the training curves.

    # Forward slashes are used for every path so the script runs on both
    # Windows and POSIX systems.
    data_after_clu = _load_pickle('./model_and_data/data_after_clu.pkl')
    ener_div = _load_pickle('./model_and_data/ener_div.pkl')
    ener_one_hot = _load_pickle('./model_and_data/ener_one_hot.pkl')

    data_after_clu = data_after_clu.to_numpy(dtype=np.int32)
    # NOTE: ener_div is loaded and converted but not used by the training below.
    ener_div = ener_div.to_numpy(dtype=np.int32)
    ener_one_hot = ener_one_hot.to_numpy(dtype=np.int32)

    # 70/30 split with a fixed seed so runs are reproducible.
    X_train, X_test, y_train, y_test = train_test_split(
        data_after_clu,
        ener_one_hot,
        test_size=0.3,
        random_state=0,
    )
    cols = data_after_clu.shape[1]
    # The output width must equal the number of one-hot label columns.
    n_classes = ener_one_hot.shape[1]

    model = keras.Sequential()
    model.add(layers.Input(shape=(cols,)))
    model.add(layers.Dense(60, activation='sigmoid'))
    # Softmax turns the final layer into a probability distribution over the
    # classes, which is what categorical cross-entropy expects when
    # from_logits=False. (A sigmoid output combined with from_logits=True
    # makes the loss treat activated values as raw logits.)
    model.add(layers.Dense(n_classes, activation='softmax'))

    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

    model.compile(
        optimizer='adam',
        loss=loss_fn,
        metrics=['categorical_accuracy'],
    )

    filepath = r'./model_and_data/ann_model.h5'
    # The monitored name must match a key Keras actually logs: the metric is
    # 'categorical_accuracy', so its validation counterpart is
    # 'val_categorical_accuracy'. Monitoring a name that is never logged means
    # save_best_only never finds a value and no checkpoint is written.
    checkpoint = ModelCheckpoint(
        filepath=filepath,
        monitor='val_categorical_accuracy',
        verbose=1,
        save_best_only=True,
        mode='max',
    )
    callback_list = [checkpoint]

    history = model.fit(
        X_train,
        y_train,
        batch_size=200,
        epochs=200,
        # The held-out split is evaluated at the end of every epoch; it feeds
        # both the checkpoint criterion and the 'Test' curves in the plot.
        validation_data=(X_test, y_test),
        callbacks=callback_list,
    )

    plot_train_test(history)

    model.summary()

